using LinearAlgebra

"""
    interpolate_y_internal(x, xq, y)

Linearly interpolate the data `y(x)` at the query points `xq`, extrapolating
linearly from the first/last interval for points outside `[x[1], x[end]]`.

Assumes `x` and `xq` are both sorted ascending (not checked): the bracketing
interval is found with a single forward sweep over `x` that never moves back.
Returns an array shaped like `xq` whose element type is the floating-point
promotion of the element types of `x`, `xq` and `y`.

Throws `ArgumentError` if `x` has fewer than two points.
"""
function interpolate_y_internal(x, xq, y)
    nx = length(x)
    nx >= 2 || throw(ArgumentError("interpolate_y_internal needs at least two points in x, got $nx"))

    # Use a float element type so integer-valued `xq` does not force an integer
    # output array (which would throw InexactError on fractional results).
    yq = similar(xq, float(promote_type(eltype(x), eltype(xq), eltype(y))))

    xi = 1
    x_low, x_high = x[1], x[2]
    for k in eachindex(xq)
        xq_cur = xq[k]
        # Advance the bracket until x_high >= xq_cur, stopping at the last interval
        # so larger points are extrapolated from it. `!(a >= b)` (rather than
        # `a < b`) keeps the original handling of NaN queries.
        while xi < nx - 1 && !(x_high >= xq_cur)
            xi += 1
            x_low, x_high = x_high, x[xi + 1]
        end
        # Weight on the lower node; falls outside [0, 1] when extrapolating.
        w = (x_high - xq_cur) / (x_high - x_low)
        yq[k] = w * y[xi] + (1 - w) * y[xi + 1]
    end
    return yq
end

"""
    interpolate_y(x, xq, y)

Linearly interpolate (and extrapolate) `y(x)` at the query points `xq` along the
last dimension, slice by slice.

Each argument may be a vector `(n,)`, a matrix `(d2, n)` or a 3-D array
`(d1, d2, n)`; vectors and matrices are lifted by prepending singleton
dimensions. Leading dimensions of size 1 are repeated to match the other
arguments. `x` and `y` share their last length; `xq` may have a different one.
Each slice of `x` and `xq` is assumed sorted ascending (not checked).

Returns a `Float64` array of size `(d1, d2, nq)`, where `nq` is the last length
of `xq`, with singleton leading dimensions dropped (a vector when `d1 == d2 == 1`).

Throws `DimensionMismatch` if a leading dimension is neither 1 nor the common size.
"""
function interpolate_y(x, xq, y)
    nq = last(size(xq))

    # Lift (n,) to (1, 1, n) and (d2, n) to (1, d2, n); 3-D inputs are used as given.
    function as3d(a)
        ndims(a) == 1 && return reshape(a, 1, 1, :)
        ndims(a) == 2 && return reshape(a, 1, size(a)...)
        return a
    end

    # Repeat a singleton dimension `d` up to length `n`. Any other mismatched size
    # is rejected: `repeat` would tile it and silently pair up the wrong slices.
    function bcast(a, d, n)
        s = size(a, d)
        s == n && return a
        s == 1 || throw(DimensionMismatch("cannot broadcast dimension $d of size $s to size $n"))
        return d == 1 ? repeat(a, n) : repeat(a, 1, n)
    end

    x, xq, y = as3d(x), as3d(xq), as3d(y)

    d1 = max(size(x, 1), size(xq, 1), size(y, 1))
    x, xq, y = bcast(x, 1, d1), bcast(xq, 1, d1), bcast(y, 1, d1)

    d2 = max(size(x, 2), size(xq, 2), size(y, 2))
    x, xq, y = bcast(x, 2, d2), bcast(xq, 2, d2), bcast(y, 2, d2)

    yq = zeros(d1, d2, nq)
    for i1 in 1:d1, i2 in 1:d2
        # Views avoid copying each 1-D slice before interpolating it.
        yq[i1, i2, :] = interpolate_y_internal(view(x, i1, i2, :), view(xq, i1, i2, :), view(y, i1, i2, :))
    end

    # Drop singleton leading dimensions so lower-rank inputs give lower-rank outputs.
    if d1 == 1 && d2 == 1
        return yq[1, 1, :]
    elseif d1 == 1
        return yq[1, :, :]
    elseif d2 == 1
        return yq[:, 1, :]
    end
    return yq
end


"""
    interpolate_coord_internal(x, xq)

Compute linear-interpolation coordinates of the query points `xq` on the grid `x`.

Returns `(xqi, xqpi)`, both shaped like `xq`:
- `xqi[k]` is the 1-based index of the lower bracketing grid point;
- `xqpi[k]` is the weight on `x[xqi[k]]`, so the interpolant of `y` is
  `xqpi[k] * y[xqi[k]] + (1 - xqpi[k]) * y[xqi[k] + 1]`. The weight lies outside
  `[0, 1]` when a point is extrapolated from the first/last interval.

Assumes `x` and `xq` are sorted ascending (not checked): the bracket is found
with a single forward sweep over `x`.

Throws `ArgumentError` if `x` has fewer than two points.
"""
function interpolate_coord_internal(x, xq)
    # The sweep is bounded by the *grid* length (previously taken from `xq` by mistake).
    nx = length(x)
    nx >= 2 || throw(ArgumentError("interpolate_coord_internal needs at least two points in x, got $nx"))

    xqi = similar(xq, Int)
    # Float weights even for integer-valued `xq` (avoids InexactError).
    xqpi = similar(xq, float(promote_type(eltype(x), eltype(xq))))

    xi = 1
    x_low, x_high = x[1], x[2]
    for k in eachindex(xq)
        xq_cur = xq[k]
        # Advance until x_high >= xq_cur, but never past the last interval.
        while xi < nx - 1 && !(x_high >= xq_cur)
            xi += 1
            x_low, x_high = x_high, x[xi + 1]
        end
        xqpi[k] = (x_high - xq_cur) / (x_high - x_low)
        xqi[k] = xi
    end

    return xqi, xqpi
end


"""
    interpolate_coord(x, xq)

Compute linear-interpolation coordinates of `xq` on the grid `x` along the last
dimension, slice by slice.

Each argument may be a vector `(n,)`, a matrix `(d2, n)` or a 3-D array
`(d1, d2, n)`; vectors and matrices are lifted by prepending singleton
dimensions, and leading dimensions of size 1 are repeated to match the other
argument. Each slice of `x` and `xq` is assumed sorted ascending (not checked).

Returns `(xqi, xqpi)`: an `Int` array of 1-based lower-bracket indices and a
`Float64` array of weights on the lower point, each of size `(d1, d2, nq)` (`nq`
is the last length of `xq`) with singleton leading dimensions dropped.

Throws `DimensionMismatch` if a leading dimension is neither 1 nor the common size.
"""
function interpolate_coord(x, xq)
    nq = last(size(xq))

    # Lift (n,) to (1, 1, n) and (d2, n) to (1, d2, n); 3-D inputs are used as given.
    function as3d(a)
        ndims(a) == 1 && return reshape(a, 1, 1, :)
        ndims(a) == 2 && return reshape(a, 1, size(a)...)
        return a
    end

    # Repeat a singleton dimension `d` up to length `n`. Any other mismatched size
    # is rejected: `repeat` would tile it and silently pair up the wrong slices.
    function bcast(a, d, n)
        s = size(a, d)
        s == n && return a
        s == 1 || throw(DimensionMismatch("cannot broadcast dimension $d of size $s to size $n"))
        return d == 1 ? repeat(a, n) : repeat(a, 1, n)
    end

    x, xq = as3d(x), as3d(xq)

    d1 = max(size(x, 1), size(xq, 1))
    x, xq = bcast(x, 1, d1), bcast(xq, 1, d1)

    d2 = max(size(x, 2), size(xq, 2))
    x, xq = bcast(x, 2, d2), bcast(xq, 2, d2)

    xqi = zeros(Int, d1, d2, nq)
    xqpi = zeros(d1, d2, nq)
    for i1 in 1:d1, i2 in 1:d2
        # Views avoid copying each 1-D slice before processing it.
        idx, wts = interpolate_coord_internal(view(x, i1, i2, :), view(xq, i1, i2, :))
        xqi[i1, i2, :] = idx
        xqpi[i1, i2, :] = wts
    end

    # Drop singleton leading dimensions so lower-rank inputs give lower-rank outputs.
    if d1 == 1 && d2 == 1
        return xqi[1, 1, :], xqpi[1, 1, :]
    elseif d1 == 1
        return xqi[1, :, :], xqpi[1, :, :]
    elseif d2 == 1
        return xqi[:, 1, :], xqpi[:, 1, :]
    end
    return xqi, xqpi
end

"""
    apply_coord_internal(x_i, x_pi, y)

Evaluate precomputed interpolation coordinates on the data `y`: entry `k` of the
result is `x_pi[k] * y[x_i[k]] + (1 - x_pi[k]) * y[x_i[k] + 1]`.

`x_i` holds 1-based lower-bracket indices (passed through `Int`, so integer-valued
floats are accepted) and `x_pi` the weights on those lower points. The result is
allocated with `similar(x_pi)`, so it has the shape and element type of `x_pi`.
"""
function apply_coord_internal(x_i, x_pi, y)
    out = similar(x_pi)
    for k in 1:length(x_i)
        lo = Int(x_i[k])
        ya, yb = y[lo], y[lo + 1]
        w = x_pi[k]
        out[k] = w * ya + (1 - w) * yb
    end
    return out
end

"""
    apply_coord(x_i, x_pi, y)

Evaluate interpolation coordinates `(x_i, x_pi)` on the data `y` along the last
dimension, slice by slice: entry `k` of a slice is
`x_pi[k] * y[x_i[k]] + (1 - x_pi[k]) * y[x_i[k] + 1]`.

Each argument may be a vector `(n,)`, a matrix `(d2, n)` or a 3-D array
`(d1, d2, n)`; vectors and matrices are lifted by prepending singleton
dimensions, and leading dimensions of size 1 are repeated to match the others.
`x_i` holds 1-based lower-bracket indices into the last dimension of `y`.

Returns a `Float64` array of size `(d1, d2, nq)` (`nq` is the last length of
`x_i`) with singleton leading dimensions dropped.

Throws `DimensionMismatch` if a leading dimension is neither 1 nor the common size.
"""
function apply_coord(x_i, x_pi, y)
    nq = last(size(x_i))

    # Lift (n,) to (1, 1, n) and (d2, n) to (1, d2, n); 3-D inputs are used as given.
    function as3d(a)
        ndims(a) == 1 && return reshape(a, 1, 1, :)
        ndims(a) == 2 && return reshape(a, 1, size(a)...)
        return a
    end

    # Repeat a singleton dimension `d` up to length `n`. Any other mismatched size
    # is rejected: `repeat` would tile it and silently pair up the wrong slices.
    function bcast(a, d, n)
        s = size(a, d)
        s == n && return a
        s == 1 || throw(DimensionMismatch("cannot broadcast dimension $d of size $s to size $n"))
        return d == 1 ? repeat(a, n) : repeat(a, 1, n)
    end

    x_i, x_pi, y = as3d(x_i), as3d(x_pi), as3d(y)

    d1 = max(size(x_i, 1), size(x_pi, 1), size(y, 1))
    x_i, x_pi, y = bcast(x_i, 1, d1), bcast(x_pi, 1, d1), bcast(y, 1, d1)

    d2 = max(size(x_i, 2), size(x_pi, 2), size(y, 2))
    x_i, x_pi, y = bcast(x_i, 2, d2), bcast(x_pi, 2, d2), bcast(y, 2, d2)

    yq = zeros(d1, d2, nq)
    for i1 in 1:d1, i2 in 1:d2
        # Views avoid copying each 1-D slice before evaluating it.
        yq[i1, i2, :] = apply_coord_internal(view(x_i, i1, i2, :), view(x_pi, i1, i2, :), view(y, i1, i2, :))
    end

    # Drop singleton leading dimensions so lower-rank inputs give lower-rank outputs.
    if d1 == 1 && d2 == 1
        return yq[1, 1, :]
    elseif d1 == 1
        return yq[1, :, :]
    elseif d2 == 1
        return yq[:, 1, :]
    end
    return yq
end

"""
    interpolate_coord_robust(x, xq; check_increasing=true)

Locate every query point of `xq` on the one-dimensional grid `x` by binary search.

Returns `(xqi, xqpi)` with the same shape as `xq`: `xqi` holds the 1-based index
of the lower bracketing grid point and `xqpi` the weight on that point. Each
query point is located independently, so `xq` need not be sorted; arrays of any
rank are flattened, processed, and reshaped back.

Throws `ArgumentError` if `x` is not one-dimensional or, when `check_increasing`
is true, if `x` is not strictly increasing.
"""
function interpolate_coord_robust(x, xq; check_increasing=true)
    ndims(x) == 1 ||
        throw(ArgumentError("Data input to interpolate_coord_robust must have exactly one dimension"))

    if check_increasing
        # Any adjacent pair that fails to increase strictly violates the contract.
        not_increasing = any(x[k] >= x[k + 1] for k in 1:length(x) - 1)
        not_increasing &&
            throw(ArgumentError("Data input to interpolate_coord_robust must be strictly increasing"))
    end

    ndims(xq) == 1 && return interpolate_coord_robust_vector(x, xq)

    idx, wts = interpolate_coord_robust_vector(x, vec(xq))
    return reshape(idx, size(xq)), reshape(wts, size(xq))
end

"""
    interpolate_coord_robust_vector(x, xq)

Compute linear-interpolation coordinates of each entry of `xq` on the grid `x`,
locating every point independently by binary search (so `xq` need not be sorted).

Returns `(xqi, xqpi)` shaped like `xq`: `xqi[k]` is the 1-based lower bracket
index, clamped to `1:length(x)-1` so out-of-range points are extrapolated from
the first/last interval, and `xqpi[k]` is the weight on `x[xqi[k]]`. The weight
array has a float element type even for integer-valued `xq`.

Assumes `x` is strictly increasing (not checked here). Throws `ArgumentError`
if `x` has fewer than two points.
"""
function interpolate_coord_robust_vector(x, xq)
    n = length(x)
    n >= 2 || throw(ArgumentError("interpolate_coord_robust_vector needs at least two points in x, got $n"))

    xqi = similar(xq, Int)
    # Float weights even for integer-valued `xq` (avoids InexactError).
    xqpi = similar(xq, float(promote_type(eltype(x), eltype(xq))))

    for iq in eachindex(xq)
        v = xq[iq]
        if v < x[1]
            ilow = 1
        elseif v > x[n - 1]
            ilow = n - 1
        else
            # Shrink [ilow, ihigh] until the two are adjacent with
            # x[ilow] < v <= x[ihigh] (or ilow == 1 when v == x[1]).
            ilow, ihigh = 1, n
            while ihigh - ilow > 1
                imid = (ihigh + ilow) ÷ 2  # floor division
                if v > x[imid]
                    ilow = imid
                else
                    ihigh = imid
                end
            end
        end

        xqi[iq] = ilow
        xqpi[iq] = (x[ilow + 1] - v) / (x[ilow + 1] - x[ilow])
    end
    return xqi, xqpi
end

"""
    interpolate_coord_njit(x, xq)

Compute linear-interpolation coordinates of the sorted query points `xq` on the
sorted grid `x` with a single forward sweep.

Returns `(xqi, xqpi)`: a `Vector{UInt32}` of 1-based lower-bracket indices and a
vector of weights on `x[xqi[k]]` (`Float64` for `Float64` inputs). Points outside
`[x[1], x[end]]` are extrapolated from the first/last interval, giving weights
outside `[0, 1]`.

Assumes both inputs are sorted ascending and 1-based (not checked). Throws
`ArgumentError` if `x` has fewer than two points.
"""
function interpolate_coord_njit(x::AbstractVector{<:Real}, xq::AbstractVector{<:Real})
    nx = length(x)
    nx >= 2 || throw(ArgumentError("interpolate_coord_njit needs at least two points in x, got $nx"))
    nxq = length(xq)

    xqi = Vector{UInt32}(undef, nxq)
    xqpi = Vector{float(promote_type(eltype(x), eltype(xq)))}(undef, nxq)

    xi = 1
    x_low, x_high = x[1], x[2]
    for k in 1:nxq
        xq_cur = xq[k]
        # Advance until x_high >= xq_cur, but never past the last interval.
        while xi < nx - 1 && !(x_high >= xq_cur)
            xi += 1
            x_low, x_high = x_high, x[xi + 1]
        end
        xqpi[k] = (x_high - xq_cur) / (x_high - x_low)
        xqi[k] = xi
    end
    return xqi, xqpi
end

"""
    apply_coord_njit(x_i, x_pi, y)

Evaluate interpolation coordinates on the data `y`: entry `k` of the result is
`x_pi[k] * y[x_i[k]] + (1 - x_pi[k]) * y[x_i[k] + 1]`.

`x_i` holds 1-based lower-bracket indices, which is the convention
`interpolate_coord_njit` produces. The previous implementation read
`y[x_i + 1]`/`y[x_i + 2]` (a 0-based convention), which was off by one and
ran out of bounds on the last interval.
"""
function apply_coord_njit(x_i::AbstractVector{<:Integer}, x_pi::AbstractVector{<:Real}, y::AbstractVector{<:Real})
    nq = length(x_i)
    length(x_pi) == nq ||
        throw(DimensionMismatch("x_i has length $nq but x_pi has length $(length(x_pi))"))
    # Float64 for Float64 inputs, matching the previous output type.
    yq = Vector{float(promote_type(eltype(x_pi), eltype(y)))}(undef, nq)

    for iq in 1:nq
        lo = x_i[iq]
        w = x_pi[iq]
        yq[iq] = w * y[lo] + (1 - w) * y[lo + 1]
    end

    return yq
end

"""
    interpolate_point(x, x0, x1, y0, y1)

Return the value at `x` of the straight line through `(x0, y0)` and `(x1, y1)`.
This interpolates when `x` lies in `[x0, x1]` and extrapolates otherwise.
Requires `x0 != x1`.
"""
function interpolate_point(x, x0, x1, y0, y1)
    # Slope is (y1 - y0) / (x1 - x0); the old code referenced an undefined `yq`.
    return y0 + (x - x0) * (y1 - y0) / (x1 - x0)
end
